{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76350774-70f0-47d7-b790-efd515d84b8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from flwr_datasets.partitioner import GroupedNaturalIdPartitioner\n",
    "from flwr_datasets.visualization import plot_label_distributions\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f367dfb9-c0f0-4098-b999-9bbd00d0cd46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load train partition of SpeechCommands\n",
    "sc = load_dataset(\"speech_commands\", \"v0.02\", split=\"train\", token=False)\n",
    "\n",
    "# Use the \"Grouped partitioner\" from FlowerDatasets to construct groups of 30 unique speaker ids\n",
    "partitioner = GroupedNaturalIdPartitioner(partition_by=\"speaker_id\", group_size=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d457b46-d649-4b41-b950-f56891ab8961",
   "metadata": {},
   "source": [
    "### Removing _silence_ clips"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a299c2ed-f7be-48ec-92a6-aa0ba2c992b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the silence audio clips (the dataset comes with 5 long audio clips. we don't want to show these in the plot below)\n",
    "# Those silence audio clips are the entries in the dataset with `speaker_id`=None. Let's remove them\n",
    "# At training time, each client will get 10% new data samples containing 1s-long silence clips\n",
    "def filter_none_speaker(example):\n",
    "    \"\"\"Return True when the example's `speaker_id` is set, False when it is None.\"\"\"\n",
    "    speaker = example[\"speaker_id\"]\n",
    "    return speaker is not None\n",
    "\n",
    "\n",
    "filtered_dataset = sc.filter(filter_none_speaker)\n",
    "\n",
    "# Apply dataset to partitioner\n",
    "partitioner.dataset = filtered_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70e31a9c-e446-4407-b774-c48d7e6edf88",
   "metadata": {},
   "source": [
    "### Making a plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78625ab-f054-4582-9b59-9a057102f434",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axis = plt.subplots(figsize=(16, 6))\n",
    "fig, ax, df = plot_label_distributions(\n",
    "    partitioner,\n",
    "    axis=axis,\n",
    "    label_name=\"label\",\n",
    "    plot_type=\"bar\",\n",
    "    size_unit=\"absolute\",\n",
    "    partition_id_axis=\"x\",\n",
    "    legend=True,\n",
    "    verbose_labels=True,\n",
    "    title=\"Per Partition Labels Distribution\",\n",
    "    legend_kwargs={\"ncols\": 2, \"bbox_to_anchor\": (1.05, 0.5)},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49a9b4c2-e291-4ae4-86a2-051265640ed8",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig.savefig(\"whisper_flower_data.png\", format=\"png\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af15f634-daf5-4709-b2a6-c98dfb8db463",
   "metadata": {},
   "source": [
    "### Process dataset into 12 classes\n",
    "\n",
    "To go from 35 classes into 12, we need to apply the following changes:\n",
    "- All audio clips that have the `is_unknown` flag set will be assigned the same \"target\" label `11`\n",
    "- Silence audio clips will be assigned label `10`\n",
    "\n",
    "We achieve this 35:12 mapping by means of the function below (similar to the one used in the code)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5092c826-f085-499a-acae-5ffcc0442757",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    \"\"\"Map one example's 35-class `label` onto the 12-class `targets` key.\n",
    "\n",
    "    Examples flagged `is_unknown` get target 11, label 35 (the silence clips)\n",
    "    becomes target 10, and every other label is kept unchanged, so the\n",
    "    resulting targets span 0-11.\n",
    "    \"\"\"\n",
    "    if batch[\"is_unknown\"]:\n",
    "        target = 11\n",
    "    elif batch[\"label\"] == 35:\n",
    "        target = 10\n",
    "    else:\n",
    "        target = batch[\"label\"]\n",
    "    return {\"targets\": target}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03131a7e-4fe0-4271-9078-a8314734b544",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_12cls = filtered_dataset.map(prepare_dataset, num_proc=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ba0c944-703f-46be-8886-34ab46fa6ba1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-construct the partitioner and apply the 12-class version of the filtered dataset\n",
    "partitioner = GroupedNaturalIdPartitioner(partition_by=\"speaker_id\", group_size=30)\n",
    "partitioner.dataset = dataset_12cls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8b319c-1056-4b68-9082-4c0057f50d2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the plot again, this time using the new \"targets\" key\n",
    "fig, axis = plt.subplots(figsize=(16, 6))\n",
    "fig, ax, df = plot_label_distributions(\n",
    "    partitioner,\n",
    "    axis=axis,\n",
    "    label_name=\"targets\",\n",
    "    plot_type=\"bar\",\n",
    "    size_unit=\"absolute\",\n",
    "    partition_id_axis=\"x\",\n",
    "    legend=True,\n",
    "    verbose_labels=False,\n",
    "    title=\"Per Partition Labels Distribution\",\n",
    "    legend_kwargs={\"ncols\": 2, \"bbox_to_anchor\": (1.0, 0.5)},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca8743ab-2f21-49f2-808c-74739f9d97aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classes 0-9 correspond to keywords: 'yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'\n",
    "# Class 10 is 'silence' and class 11 is 'other' (combined remaining classes from the 35-class original representation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
